# Code
import os
from pathlib import Path

import altair as alt
import polars as pl
def plot_metric(df, title, y_col, y_label, image_name=None, sort_order='-y'):
    """Build a bar chart of one metric per classifier.

    Args:
        df: Data handed straight to ``alt.Chart``; the encodings read a
            ``classifier`` field and the field named by ``y_col``.
        title: Chart title.
        y_col: Name of the quantitative field plotted on the Y axis.
        y_label: Y-axis title.
        image_name: When truthy, the chart is also written to
            ``./images/<image_name>.svg``.
        sort_order: Value forwarded as the ``sort`` argument of the X encoding.

    Returns:
        The Altair chart. A failure while saving is printed, not raised.
    """
    x_axis = alt.X("classifier:N", title="Classifier", sort=sort_order)
    y_axis = alt.Y(f"{y_col}:Q", title=y_label)
    bar_color = alt.Color("classifier:N", legend=alt.Legend(title="Classifier"))

    bars = alt.Chart(df).mark_bar().encode(
        x=x_axis,
        y=y_axis,
        color=bar_color,
        tooltip=["classifier", y_col],
    )
    chart = bars.properties(title=title, width=600, height=400)

    if not image_name:
        return chart

    # Saving is best-effort: the chart is returned even if the export fails.
    try:
        os.makedirs("./images", exist_ok=True)
        chart.save(f"./images/{image_name}.svg")
    except Exception as e:
        print(f"Error saving chart as SVG: {e}")
    return chart
def plot_confusion_matrix(reporter, classifier_name, image_name=None):
    """Build an annotated confusion-matrix heatmap for one classifier.

    Args:
        reporter: Object exposing ``get_confusion_matrix(classifier_name)``.
            The returned frame is read through ``is_empty()`` and its
            ``predicted_class``, ``true_class`` and ``count`` fields.
        classifier_name: Classifier to look up; also used in the chart title.
        image_name: When truthy, the chart is also written to
            ``./images/<image_name>.svg``.

    Returns:
        The layered Altair chart, or ``None`` (after printing a message) when
        the reporter returns an empty frame.
    """
    cm_df = reporter.get_confusion_matrix(classifier_name)
    if cm_df.is_empty():
        print(f"No confusion matrix data for {classifier_name}")
        return None

    axes = alt.Chart(cm_df).encode(
        x=alt.X("predicted_class:N", title="Predicted"),
        y=alt.Y("true_class:N", title="Actual"),
    )

    cell_color = alt.Color("count:Q", title="Count", scale=alt.Scale(scheme='viridis'))
    heatmap = axes.mark_rect().encode(
        color=cell_color,
        tooltip=["predicted_class", "true_class", "count"],
    )

    # Cell labels switch to white once the count exceeds half of the maximum.
    half_of_max = cm_df['count'].max() / 2
    label_color = alt.condition(
        alt.datum.count > half_of_max,
        alt.value("white"),
        alt.value("black"),
    )
    labels = axes.mark_text(align="center", baseline="middle", fontSize=12).encode(
        text=alt.Text("count:Q"),
        color=label_color,
    )

    final_chart = (heatmap + labels).properties(
        title=f"Confusion Matrix for {classifier_name}", width=400, height=400
    )

    if not image_name:
        return final_chart

    # Saving is best-effort: the chart is returned even if the export fails.
    try:
        os.makedirs("./images", exist_ok=True)
        final_chart.save(f"./images/{image_name}.svg")
    except Exception as e:
        print(f"Error saving chart as SVG: {e}")
    return final_chart
def plot_biometrics_per_class(reporter, classifier_name, metric, image_name=None):
    """Build a per-class dot plot of a biometric metric with a mean reference line.

    Args:
        reporter: Object exposing ``get_per_class_biometrics(classifier_name)``.
            The returned frame is read through ``is_empty()``, ``columns`` and
            its ``class`` field plus the field named by ``metric``.
        classifier_name: Classifier to look up; also used in the chart title.
        metric: Column to plot. ``"eer"`` sorts ascending and gets the EER axis
            label; any other value sorts descending and gets the AUC label.
        image_name: When truthy, the chart is also written to
            ``./images/<image_name>.svg``.

    Returns:
        The layered Altair chart, or ``None`` (after printing a message) when
        there is no data or ``metric`` is not a column of the frame.
    """
    df = reporter.get_per_class_biometrics(classifier_name)
    if df.is_empty():
        print(f"No per-class biometric data for {classifier_name}")
        return None
    if metric not in df.columns:
        print(f"Metric '{metric}' not found in per-class data.")
        return None

    is_eer = metric == "eer"
    if is_eer:
        sort_order = "ascending"
        y_label = "EER (Lower is Better)"
    else:
        sort_order = "descending"
        y_label = "AUC (Higher is Better)"
    title = f"{metric.upper()} per Class for {classifier_name}"

    # Dots: one per class, ordered along X by the metric value.
    class_axis = alt.X(
        "class:N",
        title="Class/Subject",
        sort=alt.EncodingSortField(field=metric, op="min", order=sort_order),
    )
    points = alt.Chart(df).mark_point(filled=True, size=80).encode(
        x=class_axis,
        y=alt.Y(f"{metric}:Q", title=y_label),
        color=alt.Color("class:N", legend=None),
        tooltip=["class", metric],
    )

    # Dashed red rule at the mean of the metric across all classes.
    mean_frame = pl.DataFrame({'y': [df[metric].mean()]})
    rule = (
        alt.Chart(mean_frame)
        .mark_rule(color='red', strokeDash=[5,5], size=2)
        .encode(y='y:Q')
    )

    final_chart = (points + rule).properties(title=title, width=800, height=400)

    if not image_name:
        return final_chart

    # Saving is best-effort: the chart is returned even if the export fails.
    try:
        os.makedirs("./images", exist_ok=True)
        final_chart.save(f"./images/{image_name}.svg")
    except Exception as e:
        print(f"Error saving chart: {e}")
    return final_chart
def plot_biometric_heatmap(reporter, metric, image_name=None):
    """Build a classifier-by-class heatmap of a biometric metric.

    Args:
        reporter: Object whose ``dataframes`` mapping may hold a
            ``"per_class_biometrics"`` frame with ``classifier`` and ``class``
            fields plus the field named by ``metric``.
        metric: Column to colour by. For ``"eer"`` the colour scale is
            reversed; for any other value it is used as-is.
        image_name: When truthy, the chart is also written to
            ``./images/<image_name>.svg``.

    Returns:
        The Altair chart, or ``None`` (after printing a message) when the frame
        is missing, empty, or lacks the ``metric`` column.
    """
    df = reporter.dataframes.get("per_class_biometrics")
    if df is None or df.is_empty():
        print("No per-class biometric data available to generate a heatmap.")
        return None
    if metric not in df.columns:
        print(f"Metric '{metric}' not found for heatmap.")
        return None

    # Green should mean "good": EER is better when low, so its scale is
    # flipped; the scale is left unflipped for every other metric.
    fill = alt.Color(
        f'{metric}:Q',
        scale=alt.Scale(scheme='redyellowgreen', reverse=(metric == 'eer')),
        legend=alt.Legend(title=f"{metric.upper()}"),
    )

    cells = alt.Chart(df).mark_rect().encode(
        x=alt.X('classifier:N', title='Classifier'),
        y=alt.Y('class:N', title='Class/Subject'),
        color=fill,
        tooltip=['classifier', 'class', metric],
    )
    chart = cells.properties(
        title=f"Heatmap of {metric.upper()} per Class and Classifier",
        width=600,
        height=800,
    )

    if not image_name:
        return chart

    # Saving is best-effort: the chart is returned even if the export fails.
    try:
        os.makedirs("./images", exist_ok=True)
        chart.save(f"./images/{image_name}.svg")
    except Exception as e:
        print(f"Error saving chart: {e}")
    return chart
def plot_far_frr_curves(reporter, classifier_name, image_name=None):
    """Build an interactive log-log FAR vs. FRR chart, one line per class.

    Args:
        reporter: Object exposing ``get_far_frr_curves(classifier_name)``. The
            returned frame is read through ``is_empty()`` and its ``far``,
            ``frr`` and ``true_class`` fields.
        classifier_name: Classifier to look up; also used in the chart title.
        image_name: When truthy, the chart is also written to
            ``./images/<image_name>.svg``.

    Returns:
        The layered Altair chart, or ``None`` (after printing a message) when
        the reporter returns an empty frame.
    """
    df = reporter.get_far_frr_curves(classifier_name)
    if df.is_empty():
        print(f"No FAR/FRR curve data for {classifier_name}")
        return None

    far_axis = alt.X(
        "far:Q",
        title="False Acceptance Rate (FAR)",
        scale=alt.Scale(type="log", domain=[0.001, 1]),
    )
    frr_axis = alt.Y(
        "frr:Q",
        title="False Rejection Rate (FRR)",
        scale=alt.Scale(type="log", domain=[0.01, 1]),
    )
    curves = alt.Chart(df).mark_line().encode(
        x=far_axis,
        y=frr_axis,
        color=alt.Color("true_class:N", legend=alt.Legend(title="Class/Subject")),
        tooltip=["true_class", "far", "frr"],
    )

    # Dashed grey diagonal where FAR equals FRR.
    # NOTE(review): the diagonal starts at y=0.001 while the Y domain above
    # starts at 0.01 - confirm whether the two lower bounds should match.
    diagonal = pl.DataFrame({'x': [0.001, 1], 'y': [0.001, 1]})
    ref_line = (
        alt.Chart(diagonal)
        .mark_line(strokeDash=[5, 5], color='gray')
        .encode(x='x:Q', y='y:Q')
    )

    layered = (curves + ref_line).properties(
        title=f"FAR vs. FRR (DET Curve) for {classifier_name}",
        width=600, height=600
    )
    final_chart = layered.interactive()

    if not image_name:
        return final_chart

    # Saving is best-effort: the chart is returned even if the export fails.
    try:
        os.makedirs("./images", exist_ok=True)
        final_chart.save(f"./images/{image_name}.svg")
    except Exception as e:
        print(f"Error saving chart: {e}")
    return final_chart


import plotly.graph_objects as go
import numpy as np
from plotly.subplots import make_subplots
from dataclasses import dataclass
from utils.preprocess import SignalPreprocessor
# Module-level preprocessor instance shared by the comparison helpers below.
signal_preprocessor = SignalPreprocessor()
@dataclass
class PlotSignalOverlapLabels:
    """Text used by ``plot_signal_overlap`` for titles and legend entries."""

    # Figure-level title.
    title: str
    # Titles of the second and third subplots (the signal name is appended).
    title1: str
    title2: str
    # Legend/trace names of the first and second signal.
    label1: str
    label2: str


# Default wording: raw signal compared against its preprocessed version.
overlap_labels = PlotSignalOverlapLabels(
    title="Complete Comparison: Overlapping, Original and Preprocessed",
    title1="Original Signal",
    title2="Preprocessed Signal",
    label1="Original",
    label2="Preprocessed",
)
def plot_signal_overlap(
    signal, signal_preprocessed, fs, signal_type, labels=overlap_labels
):
    """Show two signals overlaid and individually in a 3-row Plotly figure.

    Row 1 overlays both signals, row 2 shows ``signal`` alone and row 3 shows
    ``signal_preprocessed`` alone. The figure is displayed with ``fig.show()``.

    Args:
        signal: First signal; its length defines the time axis.
        signal_preprocessed: Second signal, drawn against the same time axis.
        fs: Sampling frequency; sample indices are divided by it to get seconds.
        signal_type: Signal name; upper-cased and appended to subplot titles.
        labels: ``PlotSignalOverlapLabels`` with the title and legend texts.

    Returns:
        ``signal_preprocessed``, unchanged.
    """
    # X axis: sample index converted to seconds.
    time_axis = np.arange(len(signal)) / fs
    signal_name = signal_type.upper()

    fig = make_subplots(
        rows=3,
        cols=1,
        subplot_titles=(
            f"Overlapping Signals ({signal_name})",
            f"{labels.title1} ({signal_name})",
            f"{labels.title2} ({signal_name})",
        ),
        shared_xaxes=True,
        vertical_spacing=0.05,
    )

    first = {"y": signal, "name": labels.label1, "color": "blue"}
    second = {"y": signal_preprocessed, "name": labels.label2, "color": "red"}

    # (series, row, opacity, show in legend). The overlay row is slightly
    # transparent and is the only one contributing legend entries, so each
    # signal is listed once.
    trace_plan = (
        (first, 1, 0.7, True),
        (second, 1, 0.7, True),
        (first, 2, 1.0, False),
        (second, 3, 1.0, False),
    )
    for series, row, opacity, in_legend in trace_plan:
        fig.add_trace(
            go.Scatter(
                x=time_axis,
                y=series["y"],
                mode="lines",
                name=series["name"],
                line=dict(color=series["color"], width=2),
                opacity=opacity,
                showlegend=in_legend,
            ),
            row=row,
            col=1,
        )

    fig.update_layout(
        title_text=labels.title,
        xaxis3_title="Time (seconds)",  # X title only under the last subplot
        yaxis_title="Amplitude",
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=1,
            font=dict(size=12),
        ),
        template="plotly_white",
        font=dict(size=12),
        height=800,
    )

    # Y titles for the two lower subplots (row 1 is set through the layout).
    for lower_row in (2, 3):
        fig.update_yaxes(title_text="Amplitude", row=lower_row, col=1)

    fig.show()
    return signal_preprocessed
@dataclass
class PlotLavels:
    """Pair of texts describing one signal in ``plot_three_signals``."""

    # Subplot title for the signal's individual row.
    title: str
    # Legend/trace name of the signal.
    name: str
def plot_three_signals(
    original_signal,
    signal1,
    signal2,
    titles: tuple[PlotLavels, PlotLavels, PlotLavels],
    fs: int,
    signal_type,
):
    """Show three signals overlaid and individually in a 4-row Plotly figure.

    Row 1 overlays all three signals; rows 2-4 show ``original_signal``,
    ``signal1`` and ``signal2`` alone, in that order. The figure is displayed
    with ``fig.show()``; nothing is returned.

    Parameters:
        original_signal (array-like): First signal; its length defines the
            time axis. Drawn in blue.
        signal1 (array-like): Second signal, drawn in green.
        signal2 (array-like): Third signal, drawn in red.
        titles (tuple): Three ``PlotLavels``, one per signal in the order
            above. ``title`` is the subplot title of the individual row and
            ``name`` is the legend/trace name.
        fs: Sampling frequency; sample indices are divided by it to get
            seconds.
        signal_type (str): Signal name; upper-cased and used in the title of
            the overlay subplot.
    """
    # X axis: sample index converted to seconds.
    time_axis = np.arange(len(original_signal)) / fs
    signal_name = signal_type.upper()

    fig = make_subplots(
        rows=4,
        cols=1,
        subplot_titles=(
            f"Overlapping Signals ({signal_name})",
            titles[0].title,
            titles[1].title,
            titles[2].title,
        ),
        shared_xaxes=True,
        vertical_spacing=0.05,
    )

    series = (
        {"y": original_signal, "name": titles[0].name, "color": "blue"},
        {"y": signal1, "name": titles[1].name, "color": "green"},
        {"y": signal2, "name": titles[2].name, "color": "red"},
    )

    # (series, row, opacity, show in legend). The overlay row is slightly
    # transparent and is the only one contributing legend entries, so each
    # signal is listed once.
    trace_plan = (
        (series[0], 1, 0.7, True),
        (series[1], 1, 0.7, True),
        (series[2], 1, 0.7, True),
        (series[0], 2, 1.0, False),
        (series[1], 3, 1.0, False),
        (series[2], 4, 1.0, False),
    )
    for item, row, opacity, in_legend in trace_plan:
        fig.add_trace(
            go.Scatter(
                x=time_axis,
                y=item["y"],
                mode="lines",
                name=item["name"],
                line=dict(color=item["color"], width=2),
                opacity=opacity,
                showlegend=in_legend,
            ),
            row=row,
            col=1,
        )

    fig.update_layout(
        title_text="Complete Comparison: Overlapping and Individual Signals",
        xaxis4_title="Time (seconds)",  # X title only under the last subplot
        yaxis_title="Amplitude",
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=1,
            font=dict(size=12),
        ),
        template="plotly_white",
        font=dict(size=12),
        height=1000,
    )

    # Y titles for the three lower subplots (row 1 is set through the layout).
    for lower_row in (2, 3, 4):
        fig.update_yaxes(title_text="Amplitude", row=lower_row, col=1)

    fig.show()
def plot_compararative_full_preprocess_vs_segment(
    signal, fs, segment_size_seconds, signal_type
):
    """Compare preprocessing the whole signal against preprocessing one segment.

    Three figures are shown, in this order:

    1. the full raw signal against the full preprocessed signal;
    2. the same comparison restricted to the second segment, i.e. samples
       ``[segment_len, 2 * segment_len)``;
    3. for that segment: the raw samples, the slice taken from the fully
       preprocessed signal, and the result of preprocessing the segment alone.

    Args:
        signal: Raw signal; must support ``len()`` and slicing.
        fs: Sampling frequency. May be a float.
        segment_size_seconds: Segment duration in seconds.
        signal_type: Signal name forwarded to the preprocessor and the plots;
            its upper-cased form is used in the labels.
    """
    title = signal_type.upper()

    # Slice bounds must be integers: with a float sampling rate (e.g. 125.0)
    # the raw product would make the slices below raise TypeError.
    segment_len = int(round(fs * segment_size_seconds))
    segment = slice(segment_len, segment_len * 2)

    signal_preprocessed = signal_preprocessor.preprocess_signal(
        signal, fs, signal_type=signal_type
    )

    # 1. Whole signal.
    plot_signal_overlap(
        signal,
        signal_preprocessed,
        fs,
        signal_type,
    )

    # 2. Second segment of both the raw and the fully preprocessed signal.
    raw_segment = signal[segment]
    full_preprocessed_segment = signal_preprocessed[segment]
    plot_signal_overlap(
        raw_segment,
        full_preprocessed_segment,
        fs,
        signal_type,
    )

    # 3. Same segment, additionally preprocessed on its own.
    signal_orginal_label = PlotLavels(f"{title} original signal", f"{title} original")
    signal_preprocessed_label = PlotLavels(
        f"{title} complete signal segment", f"{title} full preprocessing"
    )
    segment_preprocessed = signal_preprocessor.preprocess_signal(
        raw_segment, fs, signal_type=signal_type
    )
    signal_segment_label = PlotLavels(f"{title} individual segment", f"{title} segment")
    plot_three_signals(
        raw_segment,
        full_preprocessed_segment,
        segment_preprocessed,
        (signal_orginal_label, signal_preprocessed_label, signal_segment_label),
        fs,
        signal_type=signal_type,
    )
def plot_compararative_preprocess_vs_reference(signal_raw, signal_ref, fs, signal_type):
    """Preprocess a raw signal and plot it against a reference signal.

    Args:
        signal_raw: Raw signal handed to the shared preprocessor.
        signal_ref: Reference signal, plotted as the first series.
        fs: Sampling frequency, forwarded to the preprocessor and the plot.
        signal_type: Signal name, forwarded to the preprocessor and the plot.
    """
    signal_preprocessed = signal_preprocessor.preprocess_signal(
        signal_raw, fs, signal_type=signal_type
    )

    # Wording for a reference-vs-preprocessed comparison (replaces the
    # default original-vs-preprocessed labels).
    overlap_ref_labels = PlotSignalOverlapLabels(
        title="Complete Comparison: Overlapping, Preprocessed and Reference",
        title1="Reference signal",
        title2="Preprocessed signal",
        label1="Reference",
        label2="Preprocessed",
    )

    plot_signal_overlap(
        signal_ref, signal_preprocessed, fs, signal_type, labels=overlap_ref_labels
    )


import polars as pl
from IPython.display import Markdown, display
def dataframe_to_latex_htmltab(
    caption: str,
    df: pl.DataFrame,
    selected_headers: list[str] | None = None,
    show_code: bool = True,
) -> str:
    """
    Converts a Polars DataFrame to a LaTeX htmltab environment and optionally
    displays it as a syntax-highlighted (copiable) block in Jupyter.

    Args:
        caption (str): Identifier inserted into ``\\label{tab:...}``. The
            ``\\caption`` text itself is fixed ("Tabla autogenerada").
        df (pl.DataFrame): The DataFrame to convert.
        selected_headers (list[str], optional): Columns to include. When empty
            or None, all columns are used.
        show_code (bool): If True, shows the LaTeX as a Markdown code block.

    Returns:
        str: LaTeX code for the table.
    """
    if selected_headers:
        df = df.select(selected_headers)

    # Underscores in column names are escaped for LaTeX. The escaping is done
    # outside the f-string because a backslash inside a replacement field is a
    # SyntaxError before Python 3.12. Cell values are emitted verbatim.
    header_cells = [
        "\\HTtd{" + col.replace("_", "\\_") + "}" for col in df.columns
    ]
    header = " " + "\n ".join(header_cells)

    # One \HTtr{...} group per data row, each terminated by a newline.
    body = "".join(
        " \\HTtr{\n"
        + " "
        + "\n ".join(f"\\HTtd{{{cell}}}" for cell in row)
        + "\n"
        + " }\n"
        for row in df.iter_rows()
    )

    latex_table = "\n".join(
        [
            "\\begin{table}[H]",
            "\\centering",
            "\\caption{Tabla autogenerada}",
            f"\\label{{tab:{caption}}}",
            "\\begin{htmltab}",
            "\\begin{thead}",
            "\\HTtr{",
            header,
            "}",
            "\\end{thead}",
            "\\begin{tbody}",
            f"{body} \\end{{tbody}}",
            "\\end{htmltab}",
            "\\end{table}",
        ]
    )

    if show_code:
        # Show as a copiable code block in Jupyter.
        display(Markdown(f"```latex\n{latex_table}\n```"))
    return latex_table


import polars as pl
# Driver script: load one MIMIC record and compare full-signal preprocessing
# against per-segment preprocessing for its PPG channel.
mimic_path = Path("../Datasets/MIMIC")
user1 = pl.read_csv(mimic_path / "data/mimic_perform_af_001_data.csv")

# Pull each channel out of the frame as a plain Python list.
ppg, ecg, resp = (user1[column].to_list() for column in ("PPG", "ECG", "resp"))

# Arguments: signal, fs=125, segment_size_seconds=30, signal_type="ppg".
plot_compararative_full_preprocess_vs_segment(ppg, 125, 30, "ppg")